/*
 * © 2017 AgNO3 Gmbh & Co. KG
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */
package jcifs.internal.smb2.nego;

import java.util.Date;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import jcifs.CIFSContext;
import jcifs.Configuration;
import jcifs.DialectVersion;
import jcifs.internal.CommonServerMessageBlock;
import jcifs.internal.SMBProtocolDecodingException;
import jcifs.internal.SmbNegotiationRequest;
import jcifs.internal.SmbNegotiationResponse;
import jcifs.internal.smb2.ServerMessageBlock2Response;
import jcifs.internal.smb2.Smb2Constants;
import jcifs.internal.smb2.io.Smb2ReadResponse;
import jcifs.internal.smb2.io.Smb2WriteRequest;
import jcifs.internal.util.SMBUtil;
import jcifs.util.Hexdump;
import jcifs.util.transport.Response;

/**
 * SMB2 Negotiate Protocol response message.
 *
 * This response contains the server's protocol capabilities,
 * security mode, and negotiated dialect version.
 *
 * <p>
 * The raw wire fields are decoded by {@link #readBytesWireFormat(byte[], int)}. The negotiated values
 * ({@link #getSelectedDialect()}, {@link #getCommonCapabilities()}, {@link #isEncryptionSupported()}, the SMB 3.1.1
 * cipher / preauth hash selections and the clamped read/write/transact sizes) are only derived once
 * {@link #isValid(CIFSContext, SmbNegotiationRequest)} has been called and returned {@code true}.
 *
 * @author mbechler
 */
public class Smb2NegotiateResponse extends ServerMessageBlock2Response implements SmbNegotiationResponse {

    private static final Logger log = LoggerFactory.getLogger(Smb2NegotiateResponse.class);

    /**
     * Cipher id a SMB 3.1.1 server returns in its encryption context when it supports none of the ciphers offered
     * by the client. This signals "no encryption" rather than a protocol error.
     */
    private static final int NO_COMMON_CIPHER = 0;

    private int securityMode;
    private int dialectRevision;
    private final byte[] serverGuid = new byte[16];
    private int capabilities;
    private int commonCapabilities;
    private int maxTransactSize;
    private int maxReadSize;
    private int maxWriteSize;
    private long systemTime;
    private long serverStartTime;
    private NegotiateContextResponse[] negotiateContexts;
    private byte[] securityBuffer;
    private DialectVersion selectedDialect;

    private boolean supportsEncryption;
    private int selectedCipher = -1;
    private int selectedPreauthHash = -1;

    /**
     * Constructs an SMB2 negotiate response with the given configuration.
     *
     * @param cfg the configuration for this response
     */
    public Smb2NegotiateResponse(final Configuration cfg) {
        super(cfg);
    }

    /**
     * {@inheritDoc}
     *
     * Returns the credit value carried in this response's SMB2 header.
     *
     * @see jcifs.internal.SmbNegotiationResponse#getInitialCredits()
     */
    @Override
    public int getInitialCredits() {
        return getCredit();
    }

    /**
     * Gets the SMB dialect revision selected by the server, exactly as read from the wire.
     *
     * @return the dialectRevision
     */
    public int getDialectRevision() {
        return this.dialectRevision;
    }

    /**
     * Gets the server GUID used for identification.
     *
     * @return the serverGuid (the internal 16 byte array, not a copy)
     */
    public byte[] getServerGuid() {
        return this.serverGuid;
    }

    /**
     * Gets the dialect matched to the server's dialect revision.
     *
     * @return the selectedDialect, or {@code null} until {@link #isValid(CIFSContext, SmbNegotiationRequest)} has
     *         accepted the response
     */
    @Override
    public DialectVersion getSelectedDialect() {
        return this.selectedDialect;
    }

    /**
     * Gets the encryption cipher selected for SMB3 encryption.
     *
     * @return the selectedCipher, or -1 if no cipher was negotiated through an SMB 3.1.1 encryption context
     */
    public int getSelectedCipher() {
        return this.selectedCipher;
    }

    /**
     * Gets the pre-authentication integrity hash algorithm selected for SMB 3.1.1.
     *
     * @return the selectedPreauthHash, or -1 if none was negotiated
     */
    public int getSelectedPreauthHash() {
        return this.selectedPreauthHash;
    }

    /**
     * Gets the capabilities returned by the server.
     *
     * @return the server returned capabilities
     */
    public final int getCapabilities() {
        return this.capabilities;
    }

    /**
     * Gets the common capabilities negotiated between client and server.
     *
     * @return the common/negotiated capabilities (intersection of request and response capabilities, computed by
     *         {@link #isValid(CIFSContext, SmbNegotiationRequest)})
     */
    public final int getCommonCapabilities() {
        return this.commonCapabilities;
    }

    /**
     * Gets the initial security blob for authentication negotiation.
     *
     * @return initial security blob, or {@code null} if the server sent an empty security buffer
     */
    public byte[] getSecurityBlob() {
        return this.securityBuffer;
    }

    /**
     * Gets the maximum transaction size supported by the server.
     *
     * After a successful {@link #isValid(CIFSContext, SmbNegotiationRequest)} this is clamped to the configured
     * transaction buffer size minus 512 and rounded down to a multiple of 8.
     *
     * @return the maxTransactSize
     */
    public int getMaxTransactSize() {
        return this.maxTransactSize;
    }

    /**
     * {@inheritDoc}
     *
     * @see jcifs.internal.SmbNegotiationResponse#getTransactionBufferSize()
     */
    @Override
    public int getTransactionBufferSize() {
        return getMaxTransactSize();
    }

    /**
     * Gets the negotiate contexts from the SMB 3.1.1 negotiation response.
     *
     * @return the negotiateContexts; {@code null} if none were decoded. Entries are {@code null} for context types
     *         this implementation does not recognize.
     */
    public NegotiateContextResponse[] getNegotiateContexts() {
        return this.negotiateContexts;
    }

    /**
     * Gets the server start time timestamp.
     *
     * @return the serverStartTime as decoded by {@code SMBUtil.readTime}
     */
    public long getServerStartTime() {
        return this.serverStartTime;
    }

    /**
     * Gets the security mode flags from the server.
     *
     * @return the securityMode
     */
    public int getSecurityMode() {
        return this.securityMode;
    }

    /**
     * {@inheritDoc}
     *
     * Checks against the common (negotiated) capabilities; all bits of {@code cap} must be present.
     *
     * @see jcifs.internal.SmbNegotiationResponse#haveCapabilitiy(int)
     */
    @Override
    public boolean haveCapabilitiy(final int cap) {
        return (this.commonCapabilities & cap) == cap;
    }

    /**
     * {@inheritDoc}
     *
     * @see jcifs.internal.SmbNegotiationResponse#isDFSSupported()
     */
    @Override
    public boolean isDFSSupported() {
        return !getConfig().isDfsDisabled() && haveCapabilitiy(Smb2Constants.SMB2_GLOBAL_CAP_DFS);
    }

    /**
     * Checks whether SMB3 encryption is supported by the server.
     *
     * @return whether SMB encryption is supported by the server
     */
    public boolean isEncryptionSupported() {
        return this.supportsEncryption;
    }

    /**
     * {@inheritDoc}
     *
     * Reuse is allowed whenever the context configuration equals this response's configuration;
     * {@code forceSigning} is not consulted.
     *
     * @see jcifs.internal.SmbNegotiationResponse#canReuse(jcifs.CIFSContext, boolean)
     */
    @Override
    public boolean canReuse(final CIFSContext tc, final boolean forceSigning) {
        return getConfig().equals(tc.getConfig());
    }

    /**
     * Validates this response against the request that produced it and derives the negotiated parameters.
     *
     * On success this sets the selected dialect, the common capabilities, the encryption support flag, the SMB 3.1.1
     * negotiate context selections, and clamps the read/write/transact sizes to the local configuration.
     *
     * @param tc
     *            the context whose configuration bounds the negotiated values
     * @param req
     *            the negotiate request that was sent; must be a {@link Smb2NegotiateRequest}
     * @return whether the response is acceptable
     */
    @Override
    public boolean isValid(final CIFSContext tc, final SmbNegotiationRequest req) {
        if (!isReceived() || getStatus() != 0) {
            return false;
        }

        if (req.isSigningEnforced() && !isSigningEnabled()) {
            log.debug("Signing is enforced but server does not allow it");
            return false;
        }

        if (getDialectRevision() == Smb2Constants.SMB2_DIALECT_ANY) {
            log.debug("Server returned ANY dialect");
            return false;
        }

        final Smb2NegotiateRequest r = (Smb2NegotiateRequest) req;

        DialectVersion selected = null;
        for (final DialectVersion dv : DialectVersion.values()) {
            if (!dv.isSMB2()) {
                continue;
            }
            if (dv.getDialect() == getDialectRevision()) {
                selected = dv;
            }
        }

        if (selected == null) {
            log.debug("Server returned an unknown dialect");
            return false;
        }

        if (!selected.atLeast(getConfig().getMinimumVersion()) || !selected.atMost(getConfig().getMaximumVersion())) {
            log.debug("Server selected an disallowed dialect version {} (min: {} max: {})", selected, getConfig().getMinimumVersion(),
                    getConfig().getMaximumVersion());
            return false;
        }
        this.selectedDialect = selected;

        // Filter out unsupported capabilities
        this.commonCapabilities = r.getCapabilities() & this.capabilities;

        if ((this.commonCapabilities & Smb2Constants.SMB2_GLOBAL_CAP_ENCRYPTION) != 0) {
            this.supportsEncryption = tc.getConfig().isEncryptionEnabled();
        }

        if (this.selectedDialect.atLeast(DialectVersion.SMB311) && !checkNegotiateContexts(r, this.commonCapabilities)) {
            return false;
        }

        // Clamp the server-advertised sizes to what the local buffers can hold, rounded down to 8 byte multiples
        final int maxBufferSize = tc.getConfig().getTransactionBufferSize();
        this.maxReadSize =
                Math.min(maxBufferSize - Smb2ReadResponse.OVERHEAD, Math.min(tc.getConfig().getReceiveBufferSize(), this.maxReadSize))
                        & ~0x7;
        this.maxWriteSize =
                Math.min(maxBufferSize - Smb2WriteRequest.OVERHEAD, Math.min(tc.getConfig().getSendBufferSize(), this.maxWriteSize)) & ~0x7;
        this.maxTransactSize = Math.min(maxBufferSize - 512, this.maxTransactSize) & ~0x7;

        return true;
    }

    /**
     * Validates the SMB 3.1.1 negotiate contexts and records the selected cipher and preauth hash.
     *
     * A preauth integrity context is mandatory; an encryption context is mandatory only when the common capabilities
     * include encryption. Duplicate contexts of either kind are rejected. A server answering with cipher id
     * {@value #NO_COMMON_CIPHER} has no cipher in common with the client; this disables encryption instead of failing.
     *
     * @param req
     *            the request that was sent
     * @param caps
     *            the common capabilities
     * @return whether the contexts are acceptable
     */
    private boolean checkNegotiateContexts(final Smb2NegotiateRequest req, final int caps) {
        if (this.negotiateContexts == null || this.negotiateContexts.length == 0) {
            log.debug("Response lacks negotiate contexts");
            return false;
        }

        boolean foundPreauth = false, foundEnc = false;
        for (final NegotiateContextResponse ncr : this.negotiateContexts) {
            if (ncr == null) {
                // unknown context type, see createContext()
                continue;
            }
            if (!foundEnc && ncr.getContextType() == EncryptionNegotiateContext.NEGO_CTX_ENC_TYPE) {
                foundEnc = true;
                final EncryptionNegotiateContext enc = (EncryptionNegotiateContext) ncr;
                if (!checkEncryptionContext(req, enc)) {
                    return false;
                }
                final int cipher = enc.getCiphers()[0];
                if (cipher == NO_COMMON_CIPHER) {
                    // No cipher in common: the server is allowed to answer this way, continue unencrypted
                    log.debug("Server supports none of the offered ciphers, encryption disabled");
                    this.supportsEncryption = false;
                    this.commonCapabilities &= ~Smb2Constants.SMB2_GLOBAL_CAP_ENCRYPTION;
                } else {
                    this.selectedCipher = cipher;
                    this.supportsEncryption = true;
                }
            } else if (ncr.getContextType() == EncryptionNegotiateContext.NEGO_CTX_ENC_TYPE) {
                log.debug("Multiple encryption negotiate contexts");
                return false;
            } else if (!foundPreauth && ncr.getContextType() == PreauthIntegrityNegotiateContext.NEGO_CTX_PREAUTH_TYPE) {
                foundPreauth = true;
                final PreauthIntegrityNegotiateContext pi = (PreauthIntegrityNegotiateContext) ncr;
                if (!checkPreauthContext(req, pi)) {
                    return false;
                }
                this.selectedPreauthHash = pi.getHashAlgos()[0];
            } else if (ncr.getContextType() == PreauthIntegrityNegotiateContext.NEGO_CTX_PREAUTH_TYPE) {
                log.debug("Multiple preauth negotiate contexts");
                return false;
            }
        }

        if (!foundPreauth) {
            log.error("Missing preauth negotiate context");
            return false;
        }
        if (!foundEnc && (caps & Smb2Constants.SMB2_GLOBAL_CAP_ENCRYPTION) != 0) {
            log.error("Missing encryption negotiate context");
            return false;
        }
        if (!foundEnc) {
            log.debug("No encryption support");
        }
        return true;
    }

    /**
     * Checks that the server selected exactly one hash algorithm and that it was offered in the request.
     *
     * @param req
     *            the request that was sent
     * @param pc
     *            the preauth context returned by the server
     * @return whether the selection is valid
     */
    private static boolean checkPreauthContext(final Smb2NegotiateRequest req, final PreauthIntegrityNegotiateContext pc) {
        if (pc.getHashAlgos() == null || pc.getHashAlgos().length != 1) {
            log.error("Server returned no hash selection");
            return false;
        }

        PreauthIntegrityNegotiateContext rpc = null;
        for (final NegotiateContextRequest rnc : req.getNegotiateContexts()) {
            if (rnc instanceof PreauthIntegrityNegotiateContext) {
                rpc = (PreauthIntegrityNegotiateContext) rnc;
            }
        }
        if (rpc == null) {
            return false;
        }

        boolean valid = false;
        for (final int hash : rpc.getHashAlgos()) {
            if (hash == pc.getHashAlgos()[0]) {
                valid = true;
            }
        }
        if (!valid) {
            log.error("Server returned invalid hash selection");
            return false;
        }
        return true;
    }

    /**
     * Checks that the server selected exactly one cipher and that it was either offered in the request or is
     * {@value #NO_COMMON_CIPHER} ("no cipher in common").
     *
     * @param req
     *            the request that was sent
     * @param ec
     *            the encryption context returned by the server
     * @return whether the selection is valid
     */
    private static boolean checkEncryptionContext(final Smb2NegotiateRequest req, final EncryptionNegotiateContext ec) {
        if (ec.getCiphers() == null || ec.getCiphers().length != 1) {
            log.error("Server returned no cipher selection");
            return false;
        }

        EncryptionNegotiateContext rec = null;
        for (final NegotiateContextRequest rnc : req.getNegotiateContexts()) {
            if (rnc instanceof EncryptionNegotiateContext) {
                rec = (EncryptionNegotiateContext) rnc;
            }
        }
        if (rec == null) {
            log.error("Server returned an encryption context that was not requested");
            return false;
        }

        final int selected = ec.getCiphers()[0];
        if (selected == NO_COMMON_CIPHER) {
            // Legitimate answer meaning "no common cipher"; the caller disables encryption
            return true;
        }

        boolean valid = false;
        for (final int cipher : rec.getCiphers()) {
            if (cipher == selected) {
                valid = true;
            }
        }
        if (!valid) {
            log.error("Server returned invalid cipher selection");
            return false;
        }
        return true;
    }

    /**
     * {@inheritDoc}
     *
     * Returns the server's maximum read size, clamped by {@link #isValid(CIFSContext, SmbNegotiationRequest)}.
     *
     * @see jcifs.internal.SmbNegotiationResponse#getReceiveBufferSize()
     */
    @Override
    public int getReceiveBufferSize() {
        return this.maxReadSize;
    }

    /**
     * {@inheritDoc}
     *
     * Returns the server's maximum write size, clamped by {@link #isValid(CIFSContext, SmbNegotiationRequest)}.
     *
     * @see jcifs.internal.SmbNegotiationResponse#getSendBufferSize()
     */
    @Override
    public int getSendBufferSize() {
        return this.maxWriteSize;
    }

    /**
     * {@inheritDoc}
     *
     * @see jcifs.internal.SmbNegotiationResponse#isSigningEnabled()
     */
    @Override
    public boolean isSigningEnabled() {
        return (this.securityMode & Smb2Constants.SMB2_NEGOTIATE_SIGNING_ENABLED) != 0;
    }

    /**
     * {@inheritDoc}
     *
     * @see jcifs.internal.SmbNegotiationResponse#isSigningRequired()
     */
    @Override
    public boolean isSigningRequired() {
        return (this.securityMode & Smb2Constants.SMB2_NEGOTIATE_SIGNING_REQUIRED) != 0;
    }

    /**
     * {@inheritDoc}
     *
     * Equivalent to {@link #isSigningRequired()}.
     *
     * @see jcifs.internal.SmbNegotiationResponse#isSigningNegotiated()
     */
    @Override
    public boolean isSigningNegotiated() {
        return isSigningRequired();
    }

    /**
     * {@inheritDoc}
     *
     * No-op for SMB2.
     *
     * @see jcifs.internal.SmbNegotiationResponse#setupRequest(jcifs.internal.CommonServerMessageBlock)
     */
    @Override
    public void setupRequest(final CommonServerMessageBlock request) {
    }

    /**
     * {@inheritDoc}
     *
     * No-op for SMB2.
     *
     * @see jcifs.internal.SmbNegotiationResponse#setupResponse(jcifs.util.transport.Response)
     */
    @Override
    public void setupResponse(final Response resp) {
    }

    /**
     * Decodes the fixed negotiate response structure, the security buffer and (for dialect 0x0311) the negotiate
     * context list.
     *
     * @param buffer
     *            the receive buffer
     * @param bufferIndex
     *            offset of the response body (just after the SMB2 header)
     * @return the number of bytes consumed from {@code bufferIndex}; for 0x0311 responses with contexts this extends
     *         to the end of the last negotiate context
     * @throws SMBProtocolDecodingException
     *             if the structure is malformed, truncated or advertises out-of-range sizes
     */
    @Override
    protected int readBytesWireFormat(final byte[] buffer, int bufferIndex) throws SMBProtocolDecodingException {
        final int start = bufferIndex;

        // Validate minimum buffer size for SMB2 negotiate response
        if (buffer == null || buffer.length < bufferIndex + 65) {
            throw new SMBProtocolDecodingException("Buffer too small for SMB2 negotiate response (minimum 65 bytes required)");
        }

        final int structureSize = SMBUtil.readInt2(buffer, bufferIndex);
        if (structureSize != 65) {
            throw new SMBProtocolDecodingException("Structure size is not 65, got: " + structureSize);
        }

        this.securityMode = SMBUtil.readInt2(buffer, bufferIndex + 2);
        // Validate security mode flags
        if ((this.securityMode & ~(Smb2Constants.SMB2_NEGOTIATE_SIGNING_ENABLED | Smb2Constants.SMB2_NEGOTIATE_SIGNING_REQUIRED)) != 0) {
            log.warn("Server returned unknown security mode flags: 0x{}", Integer.toHexString(this.securityMode));
        }
        bufferIndex += 4;

        this.dialectRevision = SMBUtil.readInt2(buffer, bufferIndex);
        final int negotiateContextCount = SMBUtil.readInt2(buffer, bufferIndex + 2);

        // Validate negotiate context count - prevent excessive memory allocation
        if (negotiateContextCount < 0 || negotiateContextCount > 100) {
            throw new SMBProtocolDecodingException("Invalid negotiate context count: " + negotiateContextCount + " (must be 0-100)");
        }

        bufferIndex += 4;

        // Validate sufficient buffer space for server GUID and capabilities
        if (buffer.length < bufferIndex + 16 + 4 + 4 + 4 + 4) {
            throw new SMBProtocolDecodingException("Buffer too small for server GUID and capabilities section");
        }

        System.arraycopy(buffer, bufferIndex, this.serverGuid, 0, 16);
        bufferIndex += 16;

        this.capabilities = SMBUtil.readInt4(buffer, bufferIndex);
        bufferIndex += 4;

        this.maxTransactSize = SMBUtil.readInt4(buffer, bufferIndex);
        this.maxReadSize = SMBUtil.readInt4(buffer, bufferIndex + 4);
        this.maxWriteSize = SMBUtil.readInt4(buffer, bufferIndex + 8);

        // Validate reasonable buffer sizes to prevent resource exhaustion
        if (this.maxTransactSize < 0 || this.maxTransactSize > 16777216) { // 16MB max
            throw new SMBProtocolDecodingException("Invalid maxTransactSize: " + this.maxTransactSize + " (must be 0-16777216)");
        }
        if (this.maxReadSize < 0 || this.maxReadSize > 16777216) { // 16MB max
            throw new SMBProtocolDecodingException("Invalid maxReadSize: " + this.maxReadSize + " (must be 0-16777216)");
        }
        if (this.maxWriteSize < 0 || this.maxWriteSize > 16777216) { // 16MB max
            throw new SMBProtocolDecodingException("Invalid maxWriteSize: " + this.maxWriteSize + " (must be 0-16777216)");
        }

        bufferIndex += 12;

        // Validate sufficient buffer space for timestamps and offsets
        if (buffer.length < bufferIndex + 8 + 8 + 4 + 4) {
            throw new SMBProtocolDecodingException("Buffer too small for timestamps and offsets section");
        }

        this.systemTime = SMBUtil.readTime(buffer, bufferIndex);
        bufferIndex += 8;
        this.serverStartTime = SMBUtil.readTime(buffer, bufferIndex);
        bufferIndex += 8;

        final int securityBufferOffset = SMBUtil.readInt2(buffer, bufferIndex);
        final int securityBufferLength = SMBUtil.readInt2(buffer, bufferIndex + 2);
        bufferIndex += 4;

        final int negotiateContextOffset = SMBUtil.readInt4(buffer, bufferIndex);
        bufferIndex += 4;

        // Validate security buffer parameters
        if (securityBufferLength < 0 || securityBufferLength > 65536) { // 64KB max for security buffer
            throw new SMBProtocolDecodingException("Invalid security buffer length: " + securityBufferLength + " (must be 0-65536)");
        }
        if (securityBufferOffset < 0) {
            throw new SMBProtocolDecodingException("Invalid security buffer offset: " + securityBufferOffset + " (must be non-negative)");
        }

        // Offsets in the negotiate response are relative to the start of the SMB2 header
        final int hdrStart = getHeaderStart();
        if (securityBufferLength > 0) {
            // Validate that security buffer doesn't exceed available data
            if (hdrStart + securityBufferOffset < hdrStart || // Check for integer overflow
                    hdrStart + securityBufferOffset + securityBufferLength < 0 || // Check for integer overflow
                    hdrStart + securityBufferOffset + securityBufferLength > buffer.length) {
                throw new SMBProtocolDecodingException("Security buffer extends beyond available data (offset: " + securityBufferOffset
                        + ", length: " + securityBufferLength + ", buffer size: " + buffer.length + ")");
            }

            this.securityBuffer = new byte[securityBufferLength];
            System.arraycopy(buffer, hdrStart + securityBufferOffset, this.securityBuffer, 0, securityBufferLength);
            bufferIndex += securityBufferLength;
        }

        // No alignment padding is added here: nothing follows the security buffer unless negotiate contexts are
        // present, and those are located through NegotiateContextOffset below. The former "% 8" term computed the
        // remainder instead of the distance to the next boundary and could push the reported length past the
        // end of the message.

        if (this.dialectRevision == 0x0311 && negotiateContextOffset != 0 && negotiateContextCount != 0) {
            // Validate negotiate context offset
            if (negotiateContextOffset < 0) {
                throw new SMBProtocolDecodingException(
                        "Invalid negotiate context offset: " + negotiateContextOffset + " (must be non-negative)");
            }

            int ncpos = hdrStart + negotiateContextOffset;

            // Validate that negotiate context data doesn't start beyond buffer
            if (ncpos < 0 || ncpos >= buffer.length) {
                throw new SMBProtocolDecodingException(
                        "Negotiate context offset points beyond buffer (offset: " + ncpos + ", buffer size: " + buffer.length + ")");
            }

            final NegotiateContextResponse[] contexts = new NegotiateContextResponse[negotiateContextCount];
            for (int i = 0; i < negotiateContextCount; i++) {
                // Validate sufficient buffer space for context header
                if (ncpos + 8 > buffer.length) {
                    throw new SMBProtocolDecodingException("Buffer too small for negotiate context header at position " + i);
                }

                final int type = SMBUtil.readInt2(buffer, ncpos);
                final int dataLen = SMBUtil.readInt2(buffer, ncpos + 2);

                // Validate context data length
                if (dataLen < 0 || dataLen > 1024) { // 1KB max per context
                    throw new SMBProtocolDecodingException(
                            "Invalid negotiate context data length: " + dataLen + " at position " + i + " (must be 0-1024)");
                }

                ncpos += 4;
                ncpos += 4; // Reserved

                // Validate that context data doesn't exceed buffer
                if (ncpos + dataLen > buffer.length) {
                    throw new SMBProtocolDecodingException("Negotiate context data extends beyond buffer at position " + i
                            + " (data start: " + ncpos + ", length: " + dataLen + ", buffer size: " + buffer.length + ")");
                }

                // Unknown context types are skipped and leave a null slot in the array
                final NegotiateContextResponse ctx = createContext(type);
                if (ctx != null) {
                    try {
                        ctx.decode(buffer, ncpos, dataLen);
                        contexts[i] = ctx;
                    } catch (Exception e) {
                        throw new SMBProtocolDecodingException(
                                "Failed to decode negotiate context at position " + i + ": " + e.getMessage(), e);
                    }
                }
                ncpos += dataLen;
                // Every context except the last is followed by padding to an 8 byte boundary
                if (i != negotiateContextCount - 1) {
                    int padding = pad8(ncpos);
                    if (ncpos + padding > buffer.length) {
                        throw new SMBProtocolDecodingException("Negotiate context padding extends beyond buffer at position " + i);
                    }
                    ncpos += padding;
                }
            }
            this.negotiateContexts = contexts;
            return Math.max(bufferIndex, ncpos) - start;
        }

        return bufferIndex - start;
    }

    /**
     * Creates a negotiate context response based on the context type.
     *
     * @param type the negotiate context type
     * @return the appropriate negotiate context response, or null if not recognized
     */
    protected static NegotiateContextResponse createContext(final int type) {
        switch (type) {
        case EncryptionNegotiateContext.NEGO_CTX_ENC_TYPE:
            return new EncryptionNegotiateContext();
        case PreauthIntegrityNegotiateContext.NEGO_CTX_PREAUTH_TYPE:
            return new PreauthIntegrityNegotiateContext();
        case CompressionNegotiateContext.NEGO_CTX_COMPRESSION_TYPE:
            return new CompressionNegotiateContext();
        }
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * Responses are never encoded by the client, so nothing is written.
     *
     * @see jcifs.internal.smb2.ServerMessageBlock2#writeBytesWireFormat(byte[], int)
     */
    @Override
    protected int writeBytesWireFormat(final byte[] dst, final int dstIndex) {
        return 0;
    }

    @Override
    public String toString() {
        return "Smb2NegotiateResponse[" + super.toString() + ",dialectRevision=" + this.dialectRevision + ",securityMode=0x"
                + Hexdump.toHexString(this.securityMode, 1) + ",capabilities=0x" + Hexdump.toHexString(this.capabilities, 8)
                + ",serverTime=" + new Date(this.systemTime) + "]";
    }

}
